package cn.sunxiansheng.openai.client;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import cn.sunxiansheng.openai.config.properties.OpenAiProperties;
import okhttp3.*;

import javax.annotation.Resource;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Description: OpenAI 客户端类
 *
 * @Author sun
 * @Create 2024/12/14 11:56
 * @Version 1.1
 */
public class OpenAiClient {

    @Resource
    private OpenAiProperties openAiProperties;

    private static final OkHttpClient CLIENT = new OkHttpClient.Builder()
            .connectTimeout(120, TimeUnit.SECONDS)
            .readTimeout(120, TimeUnit.SECONDS)
            .writeTimeout(120, TimeUnit.SECONDS)
            .build();

    private static final Logger LOGGER = Logger.getLogger(OpenAiClient.class.getName());

    /**
     * 向 AI 提问通用方法
     *
     * @param model        使用的 AI 模型，如 "gpt-4o"
     * @param prompt       提示内容
     * @param base64Encode 是否对内容进行 Base64 编码
     * @return AI 的响应内容
     */
    public String askAI(String model, String prompt, boolean base64Encode) {
        try {
            // 处理 Base64 编码
            String encodedPrompt = base64Encode ? encodeBase64(prompt) : prompt;

            // 构造请求体
            RequestBody body = RequestBody.create(
                    createJsonRequest(model, encodedPrompt), MediaType.get("application/json; charset=utf-8")
            );

            // 构建请求
            Request request = new Request.Builder()
                    .url(openAiProperties.getApiUrl())
                    .header("Authorization", "Bearer " + openAiProperties.getApiKey())
                    .header("Content-Type", "application/json")
                    .post(body)
                    .build();

            // 发送请求并获取响应
            try (Response response = CLIENT.newCall(request).execute()) {
                if (!response.isSuccessful()) {
                    throw new IOException("Unexpected response: " + response);
                }

                // 解析 JSON 响应
                return parseResponse(response.body().string());
            }
        } catch (IOException e) {
            LOGGER.log(Level.SEVERE, "Error occurred during API request: " + e.getMessage(), e);
            throw new RuntimeException("API request failed", e);
        }
    }

    /**
     * 对输入内容进行 Base64 编码
     *
     * @param prompt 输入内容
     * @return 编码后的字符串
     */
    private String encodeBase64(String prompt) {
        return Base64.getEncoder().encodeToString(prompt.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * 构建请求的 JSON 数据
     *
     * @param model         使用的 AI 模型
     * @param encodedPrompt 编码后的输入内容
     * @return 构建好的 JSON 字符串
     */
    private String createJsonRequest(String model, String encodedPrompt) {
        JsonObject jsonRequest = new JsonObject();
        jsonRequest.addProperty("model", model);

        JsonArray messages = new JsonArray();

        // 添加 system 信息
        JsonObject systemMessage = new JsonObject();
        systemMessage.addProperty("role", "system");
        systemMessage.addProperty("content", "请根据以下内容提供问题的解决方案。使用中文回答，使用markdown语法来回答问题，注意，内容可能经过 Base64 编码。");
        messages.add(systemMessage);

        // 添加 user 信息
        JsonObject userMessage = new JsonObject();
        userMessage.addProperty("role", "user");
        userMessage.addProperty("content", encodedPrompt);
        messages.add(userMessage);

        jsonRequest.add("messages", messages);

        return jsonRequest.toString();
    }

    /**
     * 解析 API 响应内容
     *
     * @param responseBody 响应的 JSON 内容
     * @return 解析后的结果
     */
    private String parseResponse(String responseBody) {
        JsonObject jsonObject = JsonParser.parseString(responseBody).getAsJsonObject();
        JsonArray choices = jsonObject.getAsJsonArray("choices");

        if (choices != null && choices.size() > 0) {
            JsonObject choice = choices.get(0).getAsJsonObject();
            JsonObject message = choice.getAsJsonObject("message");
            return message.get("content").getAsString();
        }

        throw new RuntimeException("Invalid response: No choices found.");
    }
}
